{--
    This is pass 4 of the compiler, implemented in 'pass'.
    
    We must make sure that *type* definitions are not self-referential
    directly or indirectly.
 -}

package frege.compiler.passes.TypeAlias where

import frege.Prelude hiding (<+>)

import Data.Graph (stronglyConnectedComponents tsort)

import Lib.PP (msgdoc, text, <+>)
import frege.compiler.Utilities     as U()

import Data.List as DL(partitioned)

import  Compiler.types.Positions
import  Compiler.types.QNames
import  Compiler.types.Types
-- import  Compiler.types.Kinds
import  Compiler.types.SourceDefinitions
import  Compiler.types.Symbols
import  Compiler.types.Global as G

import  Compiler.common.Errors as E()
import  Compiler.common.SymbolTable


--- Check the type aliases among the source definitions for direct or indirect
--- self reference and, when no errors have been recorded, run 'transalias' on each of them.
---
--- The 'TypDcl' definitions are removed from the source definitions.
--- The result is the pass description and the number of type alias definitions.
pass = do
    g <- getST
    let (aliases, others) = partitioned isAlias g.sub.sourcedefs
        isAlias TypDcl{} = true
        isAlias _        = false
        -- one (alias, aliases it mentions) pair per definition, restricted to our own names
        edges   = map dependencies aliases
        groups  = tsort edges
        -- the alias definitions, rearranged in the order given by 'groups'
        ordered = [ d | grp <- groups, tn <- grp, d <- aliases, QName.base tn == DefinitionS.name d ]
        dependencies (TypDcl {name, typ}) = (TName g.thisPack name, filter (g.our) (rhoRefs typ.rho []))
        dependencies _ = error "no TypDcl"
        -- the following functions accumulate the names of the type aliases mentioned in a type
        rhoRefs (RhoTau _ tau)     seen = tauRefs tau seen
        rhoRefs (RhoFun _ sig rho) seen = rhoRefs rho (sigmaRefs sig seen)
        sigmaRefs (ForAll _ rho) seen = rhoRefs rho seen
        tauRefs (TApp a b) seen = tauRefs a (tauRefs b seen)
        tauRefs (TSig s)   seen = sigmaRefs s seen
        tauRefs (TCon{name = n}) seen = case U.nstname n g of
            Just tn
                | tn `elem` seen = seen
                | Just (SymA {name}) <- g.findit tn = if name `elem` seen then seen else name:seen
                | otherwise = seen  -- do not complain about unknown type constructors
            Nothing -> seen
        tauRefs (TVar{})   seen = seen
        tauRefs (Meta _)   seen = seen
        -- where to report a problem with the alias named tn
        positionOf tn = case g.findit tn of
            Just (SymA {pos}) -> pos
            _                 -> Position.null
        -- a group with more than one member is reported as mutually recursive
        rejectCycle (grp@(tn:_:_)) = E.error (positionOf tn) (msgdoc ("Mutual recursive type aliases "
                                ++ joined ", " [ QName.nice member g | member <- grp ]))
        rejectCycle _ = stio ()
        -- an alias that occurs in its own dependencies refers to itself
        rejectSelfRef (tn, deps) =
            if tn `elem` deps
                then E.error (positionOf tn) (msgdoc ("Self referential type alias `"
                                ++ QName.nice tn g ++ "`"))
                else stio ()
    -- no more type aliases henceforth
    changeST Global.{sub <- SubSt.{sourcedefs=reverse others}}
    foreach groups rejectCycle
    foreach edges rejectSelfRef
    final <- getST
    unless (final.errors > 0) do foreach ordered transalias
    return ("type aliases", length aliases)
    
--- Translate the type of one type alias definition and store it in the alias symbol.
---
--- When the aliased type has no @forall@ bound variables, it is first passed through
--- 'U.validSigma1' together with the alias' type arguments and then translated.
--- Otherwise the bound variables must be distinct from the type arguments, and every
--- free type variable of the type must be a type argument or a bound variable;
--- an error is reported if either condition is violated.
---
--- Definitions other than 'TypDcl' are ignored.
transalias :: DefinitionS -> StG ()
transalias (d@TypDcl {pos}) = do
        g <- getST
        let tname = TName g.thisPack d.name
        case g.findit tname of
            -- Note: the 'pos' bound here shadows the position of the definition,
            -- so the errors below are reported at the position of the symbol.
            Just sym | SymA {pos} <- sym = case d.typ.bound of
                [] -> do
                    -- type aliases may be incomplete
                    typS <- U.validSigma1 (map Tau.var d.vars) d.typ
                    typ  <- U.transSigma (ForAll [] typS.rho)
                    changeSym sym.{typ = typ.{bound=[]}}
                bound -> do
                    -- type X a b c = forall x y. ......
                    -- The bound variables x y must be distinct from the type args a b c
                    let bvars = map Tau.var bound
                        targs = map Tau.var d.vars
                        fvars = U.freeTVnames [] d.typ.rho
                        -- free variables that are neither bound nor type arguments
                        badfree = filter (`notElem` targs) (filter (`notElem` bvars) fvars)
                        -- variables that are both bound and type arguments
                        bad   = DL.intersect bvars targs
                    if null bad then 
                        if null badfree then do
                            typ1 <- U.transSigma d.typ.{bound=[]}
                            bounds ← U.transBounds bound
                            changeSym sym.{typ = typ1.{bound=bounds}} 
                            pure ()
                        else
                            -- (<+>) separates its operands with a space already,
                            -- hence no padding inside the literals
                            E.error pos (text "Type variable(s)"
                                <+> text (joined ", " badfree)
                                <+> text "are neither type args nor bound in forall"
                              )
                    else E.error pos (text "Type variable(s)"
                                <+> text (joined ", " bad)
                                <+> text "must either be type args or bound in forall, but not both.")
            -- reached when there is no symbol, or the symbol is not a type alias
            nothing -> E.fatal pos (text ("Cannot happen, type alias " ++ tname.nice g ++ " missing"))
transalias _ = return ()
